백준 4149번

큰 수 소인수분해

Posted by ChoiCube84 on July 27, 2024 · 74 mins read

백준 4149번

오늘 풀어본 문제는 백준의 4149번 문제1이다. 문제 풀이에 사용한 언어는 C++ 이다.

solved.ac 기준 CLASS

CLASS 8

문제 정보

이 문제의 내용과 조건은 다음과 같다.

문제

큰 수를 소인수분해 해보자.

입력

입력은 한 줄로 이루어져 있고, 소인수분해 해야 하는 수가 주어진다. 이 수는 $1$ 보다 크고, $2^{62}$ 보다 작다.

출력

입력으로 주어진 양의 정수를 소인수분해 한 뒤, 모든 인수를 한 줄에 하나씩 증가하는 순서로 출력한다.

풀이과정

1번째 시도

소인수분해를 위한 알고리즘을 찾아보다가, 폴라드 로 알고리즘이라는 것을 알게되었다.

폴라드 로 알고리즘을 구현하기 위해서는 어떤 수가 소수인지 판별할 수 있어야 했는데, 밀러-라빈 소수 판별법이라는 것이 있었다.

밀러-라빈 소수 판별법을 구현하기 위해서는 큰 수의 곱셈을 진행할 수 있어야했는데, 효율적인 곱셈을 위해서는 고속 푸리에 변환 (FFT) 를 이용해야 했다.

그래서 예전에 만들어둔 FFT 알고리즘을 손보고 백준 15576번 문제인 큰 수 곱셈 (2) 문제를 풀며 큰 수들의 곱셈을 할 수 있는 함수를 만들었다.

만들어진 함수를 기반으로 밀러-라빈 소수 판별법을 구현했고, 마찬가지로 폴라드 로 알고리즘을 구현할 수 있었다.

코드는 다음과 같이 작성하였다. 제출할 때는 헤더파일의 코드를 붙여넣어 제출하였다.

custom_algorithms.hpp

#ifndef __CUSTOM_ALGORITHMS_HPP__
#define __CUSTOM_ALGORITHMS_HPP__

#include <cmath>
#include <vector>
#include <complex>
#include <string>
#include <random>
#include <numeric>

namespace custom_algorithms {
    namespace fft {
        long double const_pi(void) {
            return std::atan(1) * 4;
        }

        void FFT(std::vector<std::complex<long double>>& a, const std::complex<long double>& w) {
            size_t n = a.size();

            if (n == 1) {
                return;
            }

            std::vector<std::complex<long double>> a_even(n/2), a_odd(n/2);

            for (size_t i=0; i<(n/2); i++) {
                a_even[i] = a[2 * i];
                a_odd[i] = a[2 * i + 1];
            }

            std::complex<long double> w_squared = w * w;

            FFT(a_even, w_squared);
            FFT(a_odd, w_squared);

            std::complex<long double> w_i = 1;

            for (size_t i=0; i<(n/2); i++) {
                a[i] = a_even[i] + w_i * a_odd[i];
                a[i + (n/2)] = a_even[i] - w_i * a_odd[i];

                w_i *= w;
            }
        }

        std::vector<std::complex<long double>> convolution(std::vector<std::complex<long double>> a, std::vector<std::complex<long double>> b, bool getIntegerResult = false) {
            size_t n = 1;
            long double pi = const_pi();

            while (n <= a.size() || n <= b.size()) {
                n <<= 1;
            }
            n <<= 1;

            a.resize(n);
            b.resize(n);

            std::vector<std::complex<long double>> c(n);

            std::complex<long double> w(cos(2 * pi / n), sin(2 * pi / n));

            FFT(a, w);
            FFT(b, w);

            for (int i = 0; i < n; i++) {
                c[i] = a[i] * b[i];
            }

            FFT(c, std::complex<long double>(w.real(), -w.imag()));

            for (int i = 0; i < n; i++) {
                c[i] /= std::complex<long double>(n, 0);
                if (getIntegerResult) {
                    c[i] = std::complex<long double>(round(c[i].real()), round(c[i].imag()));
                }
            }

            return c;
        }

        template <typename T>
        std::vector<T> stringToVector(const std::string& str) {
            std::vector<T> result(str.size());

            for (size_t i=0; i<str.size(); i++) {
                result[i] = static_cast<T>(str[i] - '0');
            }

            return result;
        }

        template <typename T>
        std::string vectorToString(const std::vector<T>& vec) {
            for (size_t i=vec.size()-1; i>0; i--) {
                vec[i-1] += (vec[i] / 10);
                vec[i] %= 10;
            }

            std::string result;

            for (auto& digit : vec) {
                result += static_cast<char>(digit + '0');
            }

            return result;
        }

        template <typename T>
        std::string fastMultiplication(const T& A, const T& B) {
            return fastMultiplication(std::to_string(A), std::to_string(B));
        }

        template <>
        std::string fastMultiplication(const std::string& A, const std::string& B) {
            std::vector<int> a = stringToVector<int>(A);
            std::vector<int> b = stringToVector<int>(B);

            size_t n = a.size() + b.size() - 1;

            std::vector<std::complex<long double>> a_complex(a.begin(), a.end());
            std::vector<std::complex<long double>> b_complex(b.begin(), b.end());

            std::vector<std::complex<long double>> conv = convolution(a_complex, b_complex, true);
            std::vector<int> digitArray(n, 0);
            
            for (size_t i=0; i<n; i++) {
                digitArray[i] = static_cast<int>(conv[i].real());
            }

            for (int i=digitArray.size()-1; i>0; i--) {
                digitArray[i-1] += (digitArray[i] / 10);
                digitArray[i] %= 10;
            }
            
            std::string result;
            for (auto& digit : digitArray) {
                result += std::to_string(digit);
            }
            
            return result;
        }
    }

    namespace common {
        template <typename T>
        T stoiWithMOD(const std::string& s, const T& MOD=static_cast<T>(0)) {
            T result = static_cast<T>(0);

            for (auto& c : s) {
                result *= 10;

                if (MOD != 0) {
                    result %= MOD;
                }

                T added = static_cast<T>(c - '0');
                if (MOD != 0) {
                    added %= MOD;
                }

                result += added;

                if (MOD != 0) {
                    result %= MOD;
                }
            }
            return result;
        }

        template <typename T>
        T power(const T& a, const T& b, const T& MOD=static_cast<T>(0)){
            T result = static_cast<T>(1);

            std::string (*mult)(const T&, const T&) = fft::fastMultiplication<T>;

            T base = a;
            T exponent = b;

            if (MOD != 0) {
                base %= MOD;
            }

            while (exponent) {
                if (exponent % 2 == 1) {
                    result = stoiWithMOD(mult(result, base), MOD);
                }

                base = stoiWithMOD(mult(base, base), MOD);
                exponent >>= 1;
            }

            return result;
        }
    }

    namespace miller_rabin {
        std::vector<int> basicPrimes = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

        bool isComposite(unsigned long long int a, unsigned long long int n) {
            unsigned long long int k = n - 1;

            while (true) {
                unsigned long long int d = common::power(a, k, n);

                if (k % 2 == 1) {
                    return (d != 1 && d != n - 1);
                }
                else if (d == n - 1) {
                    return false;
                }

                k /= 2;
            }
        }

        bool isPrime(unsigned long long int n) {
            if (n <= 1) {
                return false;
            }

            for (auto& prime : basicPrimes){
                if (n == prime) {
                    return true;
                }
                else if (n % prime == 0) {
                    return false;
                }
            }

            for (auto& prime : basicPrimes) {
                if (isComposite(prime, n)) {
                    return false;
                }
            }

            return true;
        }
    }

    namespace pollard_rho {
        unsigned long long int findFactor(unsigned long long int n) {
            static std::mt19937_64 mt(std::random_device{}());

            static std::uniform_int_distribution<unsigned long long int> dist1(2, n);
            static std::uniform_int_distribution<unsigned long long int> dist2(1, n);

            std::string (*mult)(const unsigned long long int&, const unsigned long long int&) = fft::fastMultiplication<unsigned long long int>;

            if (n == 1) {
                return 1;
            }
            else if (n % 2 == 0) {
                return 2;
            }
            else if (miller_rabin::isPrime(n)) {
                return n;
            }
            else {
                unsigned long long int x = dist1(mt);
                unsigned long long int y = x;

                unsigned long long int c = dist2(mt);
                unsigned long long int d = 1;

                while (d == 1) {
                    x = (common::stoiWithMOD(mult(x, x), n) + c) % n;

                    y = (common::stoiWithMOD(mult(y, y), n) + c) % n;
                    y = (common::stoiWithMOD(mult(y, y), n) + c) % n;

                    d = std::gcd(n, (x > y ? x - y : y - x));

                    if (d == n) {
                        return findFactor(n);
                    }
                }

                if (miller_rabin::isPrime(d)) {
                    return d;
                }
                else {
                    return findFactor(d);
                }
            }
        }

        std::vector<std::pair<unsigned long long int, unsigned long long int>> factorize(unsigned long long int n) {
            std::vector<std::pair<unsigned long long int, unsigned long long int>> result;

            struct cmp {
                bool operator()(const std::pair<unsigned long long int, unsigned long long int>& a, const std::pair<unsigned long long int, unsigned long long int>& b) {
                    return a.first < b.first;
                }
            };

            while (n > 1) {
                unsigned long long int factor = findFactor(n);

                n /= factor;
                result.emplace_back(std::make_pair(factor, 1));

                while (n % factor == 0) {
                    n /= factor;
                    result.back().second++;
                }
            }

            std::sort(result.begin(), result.end(), cmp());

            return result;
        }
    }
}

#endif // __CUSTOM_ALGORITHMS_HPP__

main.hpp

#include <bits/extc++.h>

using namespace __gnu_pbds;
using namespace std;

using namespace custom_algorithms;

using ll = long long int;
using ull = unsigned long long int;

using pll = pair<ll, ll>;

int main(void) {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    ull n;
    cin >> n;

    vector<pair<ull, ull>> result = pollard_rho::factorize(n);

    for (auto& [key, value] : result) {
        for (ull i=0; i<value; i++) {
            cout << key << '\n';
        }
    }

    return 0;
}

실행 결과 ‘시간 초과’ 가 떴다.

2번째 시도

질문 게시판을 떠돌아다니며 여러 반례들을 테스트하며 함수들을 뜯어본 결과, std::string 자료형으로 저장된 수를 특정 값으로 나눈 나머지로 바꿔주는 stoiWithMOD 함수에서 result *= 10 부분에서 오버플로우가 일어나 계산이 잘못되어 결과를 알아낼 수 없을 수 있겠다는 결론에 도달하였다.

그래서 바로 $10$ 을 곱하는 대신 원래 수에 $2$ 를 곱하고 MOD 값으로 나머지 연산을 취해준 뒤, 그 수에 $2$ 를 곱하고 MOD 값으로 나머지 연산을 취하는 과정을 두 번 더 진행한 값을 원래 수와 더한 뒤 MOD 값으로 나머지 연산을 취해주는 방식으로 $10$ 을 곱하고 MOD 값으로 나머지 연산을 취하는 것을 구현하였다.

시간은 조금 걸려도 게시판의 반례들을 전부 통과할 수 있었기 때문에 코드가 제대로 동작한다는 것을 알 수 있었다.

코드는 헤더파일만 다음과 같이 수정하였다.

custom_algorithms.hpp

#ifndef __CUSTOM_ALGORITHMS_HPP__
#define __CUSTOM_ALGORITHMS_HPP__

#include <cmath>
#include <vector>
#include <complex>
#include <string>
#include <random>
#include <numeric>

namespace custom_algorithms {
    namespace fft {
        long double const_pi(void) {
            return std::atan(1) * 4;
        }

        void FFT(std::vector<std::complex<long double>>& a, const std::complex<long double>& w) {
            size_t n = a.size();

            if (n == 1) {
                return;
            }

            std::vector<std::complex<long double>> a_even(n/2), a_odd(n/2);

            for (size_t i=0; i<(n/2); i++) {
                a_even[i] = a[2 * i];
                a_odd[i] = a[2 * i + 1];
            }

            std::complex<long double> w_squared = w * w;

            FFT(a_even, w_squared);
            FFT(a_odd, w_squared);

            std::complex<long double> w_i = 1;

            for (size_t i=0; i<(n/2); i++) {
                a[i] = a_even[i] + w_i * a_odd[i];
                a[i + (n/2)] = a_even[i] - w_i * a_odd[i];

                w_i *= w;
            }
        }

        std::vector<std::complex<long double>> convolution(std::vector<std::complex<long double>> a, std::vector<std::complex<long double>> b, bool getIntegerResult = false) {
            size_t n = 1;
            long double pi = const_pi();

            while (n <= a.size() || n <= b.size()) {
                n <<= 1;
            }
            n <<= 1;

            a.resize(n);
            b.resize(n);

            std::vector<std::complex<long double>> c(n);

            std::complex<long double> w(cos(2 * pi / n), sin(2 * pi / n));

            FFT(a, w);
            FFT(b, w);

            for (int i = 0; i < n; i++) {
                c[i] = a[i] * b[i];
            }

            FFT(c, std::complex<long double>(w.real(), -w.imag()));

            for (int i = 0; i < n; i++) {
                c[i] /= std::complex<long double>(n, 0);
                if (getIntegerResult) {
                    c[i] = std::complex<long double>(round(c[i].real()), round(c[i].imag()));
                }
            }

            return c;
        }

        template <typename T>
        std::vector<T> stringToVector(const std::string& str) {
            std::vector<T> result(str.size());

            for (size_t i=0; i<str.size(); i++) {
                result[i] = static_cast<T>(str[i] - '0');
            }

            return result;
        }

        template <typename T>
        std::string vectorToString(const std::vector<T>& vec) {
            for (size_t i=vec.size()-1; i>0; i--) {
                vec[i-1] += (vec[i] / 10);
                vec[i] %= 10;
            }

            std::string result;

            for (auto& digit : vec) {
                result += static_cast<char>(digit + '0');
            }

            return result;
        }

        template <typename T>
        std::string fastMultiplication(const T& A, const T& B) {
            return fastMultiplication(std::to_string(A), std::to_string(B));
        }

        template <>
        std::string fastMultiplication(const std::string& A, const std::string& B) {
            std::vector<int> a = stringToVector<int>(A);
            std::vector<int> b = stringToVector<int>(B);

            size_t n = a.size() + b.size() - 1;

            std::vector<std::complex<long double>> a_complex(a.begin(), a.end());
            std::vector<std::complex<long double>> b_complex(b.begin(), b.end());

            std::vector<std::complex<long double>> conv = convolution(a_complex, b_complex, true);
            std::vector<int> digitArray(n, 0);
            
            for (size_t i=0; i<n; i++) {
                digitArray[i] = static_cast<int>(conv[i].real());
            }

            for (int i=digitArray.size()-1; i>0; i--) {
                digitArray[i-1] += (digitArray[i] / 10);
                digitArray[i] %= 10;
            }
            
            std::string result;
            for (auto& digit : digitArray) {
                result += std::to_string(digit);
            }
            
            return result;
        }
    }

    namespace common {
        template <typename T>
        T stoiWithMOD(const std::string& s, const T& MOD=static_cast<T>(0)) {
            T result = static_cast<T>(0);

            for (auto& c : s) {
                T temp = result;

                result *= 2;
                temp *= 2;

                if (MOD != 0) {
                    result %= MOD;
                    temp %= MOD;
                }

                temp *= 2;

                if (MOD != 0) {
                    temp %= MOD;
                }

                temp *= 2;

                if (MOD != 0) {
                    temp %= MOD;
                }

                result += temp;

                if (MOD != 0) {
                    result %= MOD;
                }

                T added = static_cast<T>(c - '0');
                if (MOD != 0) {
                    added %= MOD;
                }

                result += added;

                if (MOD != 0) {
                    result %= MOD;
                }
            }
            return result;
        }

        template <typename T>
        T power(const T& a, const T& b, const T& MOD=static_cast<T>(0)){
            T result = static_cast<T>(1);

            std::string (*mult)(const T&, const T&) = fft::fastMultiplication<T>;

            T base = a;
            T exponent = b;

            if (MOD != 0) {
                base %= MOD;
            }

            while (exponent) {
                if (exponent % 2 == 1) {
                    result = stoiWithMOD(mult(result, base), MOD);
                }

                base = stoiWithMOD(mult(base, base), MOD);
                exponent >>= 1;
            }

            return result;
        }
    }

    namespace miller_rabin {
        std::vector<int> basicPrimes = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

        bool isComposite(unsigned long long int a, unsigned long long int n) {
            unsigned long long int k = n - 1;

            while (true) {
                unsigned long long int d = common::power(a, k, n);

                if (k % 2 == 1) {
                    return (d != 1 && d != n - 1);
                }
                else if (d == n - 1) {
                    return false;
                }

                k /= 2;
            }
        }

        bool isPrime(unsigned long long int n) {
            if (n <= 1) {
                return false;
            }

            for (auto& prime : basicPrimes){
                if (n == prime) {
                    return true;
                }
                else if (n % prime == 0) {
                    return false;
                }
            }

            for (auto& prime : basicPrimes) {
                if (isComposite(prime, n)) {
                    return false;
                }
            }

            return true;
        }
    }

    namespace pollard_rho {
        unsigned long long int findFactor(unsigned long long int n) {
            static std::mt19937_64 mt(std::random_device{}());

            static std::uniform_int_distribution<unsigned long long int> dist1(2, n);
            static std::uniform_int_distribution<unsigned long long int> dist2(1, n);

            std::string (*mult)(const unsigned long long int&, const unsigned long long int&) = fft::fastMultiplication<unsigned long long int>;

            if (n == 1) {
                return 1;
            }
            else if (n % 2 == 0) {
                return 2;
            }
            else if (miller_rabin::isPrime(n)) {
                return n;
            }
            else {
                unsigned long long int x = dist1(mt);
                unsigned long long int y = x;

                unsigned long long int c = dist2(mt);
                unsigned long long int d = 1;

                while (d == 1) {
                    x = (common::stoiWithMOD(mult(x, x), n) + c) % n;

                    y = (common::stoiWithMOD(mult(y, y), n) + c) % n;
                    y = (common::stoiWithMOD(mult(y, y), n) + c) % n;

                    d = std::gcd(n, (x > y ? x - y : y - x));

                    if (d == n) {
                        return findFactor(n);
                    }
                }

                if (miller_rabin::isPrime(d)) {
                    return d;
                }
                else {
                    return findFactor(d);
                }
            }
        }

        std::vector<std::pair<unsigned long long int, unsigned long long int>> factorize(unsigned long long int n) {
            std::vector<std::pair<unsigned long long int, unsigned long long int>> result;

            struct cmp {
                bool operator()(const std::pair<unsigned long long int, unsigned long long int>& a, const std::pair<unsigned long long int, unsigned long long int>& b) {
                    return a.first < b.first;
                }
            };

            while (n > 1) {
                unsigned long long int factor = findFactor(n);

                n /= factor;
                result.emplace_back(std::make_pair(factor, 1));

                while (n % factor == 0) {
                    n /= factor;
                    result.back().second++;
                }
            }

            std::sort(result.begin(), result.end(), cmp());

            return result;
        }
    }

    namespace floyd_warshall {
        template <template<typename, typename> typename Table, typename Node, typename Distance>
        Table<Node, Table<Node, Distance>> getShortestPath(const Table<Node, Table<Node, Distance>>& graph) {
            Table<Node, Table<Node, Distance>> distance = graph;

            for (auto [middle, _] : distance) {
                for (auto [start, _] : distance) {
                    for (auto [end, _] : distance) {
                        if (distance[start][end] > distance[start][middle] + distance[middle][end]) {
                            distance[start][end] = distance[start][middle] + distance[middle][end];
                        }
                    }
                }
            }

            return distance;
        }
    }
}

#endif // __CUSTOM_ALGORITHMS_HPP__

실행 결과 ‘시간 초과’ 가 떴다.

3번째 시도

이번에 난 시간 초과는 다른 함수들이 잘못 구현되어 폴라드 로 함수에서 무한루프에 빠진 것이 아닌, 그냥 코드 자체가 느려서 시간 초과를 받은 것이라고 추측했다. 그래서 큰 수의 곱셈을 위해 궁극의 무기인 __int128 자료형을 사용하기로 하였다.

코드는 헤더파일만 다음과 같이 수정하였다.

custom_algorithms.hpp

#ifndef __CUSTOM_ALGORITHMS_HPP__
#define __CUSTOM_ALGORITHMS_HPP__

#include <cmath>
#include <vector>
#include <complex>
#include <string>
#include <random>
#include <numeric>

namespace custom_algorithms {
    namespace fft {
        long double const_pi(void) {
            return std::atan(1) * 4;
        }

        void FFT(std::vector<std::complex<long double>>& a, const std::complex<long double>& w) {
            size_t n = a.size();

            if (n == 1) {
                return;
            }

            std::vector<std::complex<long double>> a_even(n/2), a_odd(n/2);

            for (size_t i=0; i<(n/2); i++) {
                a_even[i] = a[2 * i];
                a_odd[i] = a[2 * i + 1];
            }

            std::complex<long double> w_squared = w * w;

            FFT(a_even, w_squared);
            FFT(a_odd, w_squared);

            std::complex<long double> w_i = 1;

            for (size_t i=0; i<(n/2); i++) {
                a[i] = a_even[i] + w_i * a_odd[i];
                a[i + (n/2)] = a_even[i] - w_i * a_odd[i];

                w_i *= w;
            }
        }

        std::vector<std::complex<long double>> convolution(std::vector<std::complex<long double>> a, std::vector<std::complex<long double>> b, bool getIntegerResult = false) {
            size_t n = 1;
            long double pi = const_pi();

            while (n <= a.size() || n <= b.size()) {
                n <<= 1;
            }
            n <<= 1;

            a.resize(n);
            b.resize(n);

            std::vector<std::complex<long double>> c(n);

            std::complex<long double> w(cos(2 * pi / n), sin(2 * pi / n));

            FFT(a, w);
            FFT(b, w);

            for (int i = 0; i < n; i++) {
                c[i] = a[i] * b[i];
            }

            FFT(c, std::complex<long double>(w.real(), -w.imag()));

            for (int i = 0; i < n; i++) {
                c[i] /= std::complex<long double>(n, 0);
                if (getIntegerResult) {
                    c[i] = std::complex<long double>(round(c[i].real()), round(c[i].imag()));
                }
            }

            return c;
        }

        template <typename T>
        std::vector<T> stringToVector(const std::string& str) {
            std::vector<T> result(str.size());

            for (size_t i=0; i<str.size(); i++) {
                result[i] = static_cast<T>(str[i] - '0');
            }

            return result;
        }

        template <typename T>
        std::string vectorToString(const std::vector<T>& vec) {
            for (size_t i=vec.size()-1; i>0; i--) {
                vec[i-1] += (vec[i] / 10);
                vec[i] %= 10;
            }

            std::string result;

            for (auto& digit : vec) {
                result += static_cast<char>(digit + '0');
            }

            return result;
        }

        template <typename T>
        std::string fastMultiplication(const T& A, const T& B) {
            return fastMultiplication(std::to_string(A), std::to_string(B));
        }

        template <>
        std::string fastMultiplication(const std::string& A, const std::string& B) {
            std::vector<int> a = stringToVector<int>(A);
            std::vector<int> b = stringToVector<int>(B);

            size_t n = a.size() + b.size() - 1;

            std::vector<std::complex<long double>> a_complex(a.begin(), a.end());
            std::vector<std::complex<long double>> b_complex(b.begin(), b.end());

            std::vector<std::complex<long double>> conv = convolution(a_complex, b_complex, true);
            std::vector<int> digitArray(n, 0);

            for (size_t i=0; i<n; i++) {
                digitArray[i] = static_cast<int>(conv[i].real());
            }

            for (int i=digitArray.size()-1; i>0; i--) {
                digitArray[i-1] += (digitArray[i] / 10);
                digitArray[i] %= 10;
            }

            std::string result;
            for (auto& digit : digitArray) {
                result += std::to_string(digit);
            }

            return result;
        }
    }

    namespace common {
        template <typename T>
        T stoiWithMOD(const std::string& s, const T& MOD=static_cast<T>(0)) {
            T result = static_cast<T>(0);

            for (auto& c : s) {
                result *= 2;

                if (MOD != 0) {
                    result %= MOD;
                }

                T temp = result;

                temp *= 2;

                if (MOD != 0) {
                    temp %= MOD;
                }

                temp *= 2;

                if (MOD != 0) {
                    temp %= MOD;
                }

                result += temp;

                if (MOD != 0) {
                    result %= MOD;
                }

                T added = static_cast<T>(c - '0');
                if (MOD != 0) {
                    added %= MOD;
                }

                result += added;

                if (MOD != 0) {
                    result %= MOD;
                }
            }
            return result;
        }

        template <typename T>
        T multWithMOD_int128(const T& a, const T& b, const T& MOD=static_cast<T>(0)) {
            __int128 result = a;

            result *= static_cast<__int128>(b);

            if (MOD != 0) {
                result %= MOD;
            }

            return result;
        }

        template <typename T>
        T power(const T& a, const T& b, const T& MOD=static_cast<T>(0), bool useInt128 = true){
            T result = static_cast<T>(1);

            std::string (*mult)(const T&, const T&) = fft::fastMultiplication<T>;

            T base = a;
            T exponent = b;

            if (MOD != 0) {
                base %= MOD;
            }

            while (exponent) {
                if (exponent % 2 == 1) {
                    if (!useInt128) {
                        result = stoiWithMOD(mult(result, base), MOD);
                    }
                    else {
                        result = multWithMOD_int128(result, base, MOD);
                    }
                }

                if (!useInt128) {
                    base = stoiWithMOD(mult(base, base), MOD);
                }
                else {
                    base = multWithMOD_int128(base, base, MOD);
                }

                exponent >>= 1;
            }

            return result;
        }
    }

    namespace miller_rabin {
        std::vector<int> basicPrimes = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

        bool isComposite(unsigned long long int a, unsigned long long int n, bool useInt128 = true) {
            unsigned long long int k = n - 1;

            while (true) {
                unsigned long long int d = common::power(a, k, n, useInt128);

                if (k % 2 == 1) {
                    return (d != 1 && d != n - 1);
                }
                else if (d == n - 1) {
                    return false;
                }

                k /= 2;
            }
        }

        bool isPrime(unsigned long long int n, bool useInt128 = true) {
            if (n <= 1) {
                return false;
            }

            for (auto& prime : basicPrimes){
                if (n == prime) {
                    return true;
                }
                else if (n % prime == 0) {
                    return false;
                }
            }

            for (auto& prime : basicPrimes) {
                if (isComposite(prime, n, useInt128)) {
                    return false;
                }
            }

            return true;
        }
    }

    namespace pollard_rho {
        unsigned long long int findFactor(unsigned long long int n, bool useInt128 = true) {
            static std::mt19937_64 mt(std::random_device{}());

            static std::uniform_int_distribution<unsigned long long int> dist1(2, n);
            static std::uniform_int_distribution<unsigned long long int> dist2(1, n);

            std::string (*mult)(const unsigned long long int&, const unsigned long long int&) = fft::fastMultiplication<unsigned long long int>;

            if (n == 1) {
                return 1;
            }
            else if (n % 2 == 0) {
                return 2;
            }
            else if (miller_rabin::isPrime(n)) {
                return n;
            }
            else {
                unsigned long long int x = dist1(mt);
                unsigned long long int y = x;

                unsigned long long int c = dist2(mt);
                unsigned long long int d = 1;

                while (d == 1) {
                    if (!useInt128) {
                        x = (common::stoiWithMOD(mult(x, x), n) + c) % n;

                        y = (common::stoiWithMOD(mult(y, y), n) + c) % n;
                        y = (common::stoiWithMOD(mult(y, y), n) + c) % n;
                    }
                    else {
                        x = common::multWithMOD_int128(x, x, n) + c;

                        y = common::multWithMOD_int128(y, y, n) + c;
                        y = common::multWithMOD_int128(y, y, n) + c;
                    }

                    d = std::gcd(n, (x > y ? x - y : y - x));

                    if (d == n) {
                        return findFactor(n);
                    }
                }

                if (miller_rabin::isPrime(d, useInt128)) {
                    return d;
                }
                else {
                    return findFactor(d);
                }
            }
        }

        std::vector<std::pair<unsigned long long int, unsigned long long int>> factorize(unsigned long long int n, bool useInt128 = true) {
            std::vector<std::pair<unsigned long long int, unsigned long long int>> result;

            struct cmp {
                bool operator()(const std::pair<unsigned long long int, unsigned long long int>& a, const std::pair<unsigned long long int, unsigned long long int>& b) {
                    return a.first < b.first;
                }
            };

            while (n > 1) {
                unsigned long long int factor = findFactor(n, useInt128);

                n /= factor;
                result.emplace_back(std::make_pair(factor, 1));

                while (n % factor == 0) {
                    n /= factor;
                    result.back().second++;
                }
            }

            std::sort(result.begin(), result.end(), cmp());

            return result;
        }
    }

    namespace floyd_warshall {
        template <template<typename, typename> typename Table, typename Node, typename Distance>
        Table<Node, Table<Node, Distance>> getShortestPath(const Table<Node, Table<Node, Distance>>& graph) {
            Table<Node, Table<Node, Distance>> distance = graph;

            for (auto [middle, _] : distance) {
                for (auto [start, _] : distance) {
                    for (auto [end, _] : distance) {
                        if (distance[start][end] > distance[start][middle] + distance[middle][end]) {
                            distance[start][end] = distance[start][middle] + distance[middle][end];
                        }
                    }
                }
            }

            return distance;
        }
    }
}

#endif // __CUSTOM_ALGORITHMS_HPP__

그러자 모든 테스트 케이스를 통과하고 ‘맞았습니다’가 나오는 것을 확인할 수 있었다.

마무리

이 문제를 풀어내면서, 이제 나는 빠르고 효율적으로 소인수분해를 진행해주는 함수를 얻게 되었고 CLASS 8 문제를 미리 하나 해치워버렸다! 풀기 전까지는 클래스 문제인지 몰랐는데 조금 당황스러웠다…

사실 이 문제를 푼 것은 다른 문제를 풀기 위해 소인수분해를 구현하는 김에 푼 문제인데, 그 문제는 다음 글에 나올 것이다.

내가 만든 알고리즘 헤더파일은 내 레포지토리2에서 확인할 수 있다.

오늘의 PS는 아직 끝나지 않았다…!


1: https://www.acmicpc.net/problem/4149
2: https://github.com/ChoiCube84/baekjoon-solutions/blob/main/C%2B%2B/custom_algorithms.hpp